%matplotlib inline
%load_ext autoreload
%autoreload 2
import argparse
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd
# Make the project root importable so the `src` package below can be found.
# (os.path.join with a single argument is a no-op, so abspath alone suffices.)
module_path = os.path.abspath('/users/dli44/tool-presence')
if module_path not in sys.path:
    sys.path.append(module_path)
from src import constants as c
from src import utils
from src import visualization as v
from src import model as m
# A notebook has no real command line, so feed the project's argument parser
# a fixed argv list instead.
notebook_argv = [
    '--root=/users/dli44/tool-presence/',
    '--data-dir=data/surgical_data/',
    '--image-size=64',
    '--loss-function=mmd',
]
parser = utils.setup_argparse()
args = parser.parse_args(args=notebook_argv)
# Datasets/dataloaders built by the project helper, with augmentation disabled.
datasets, dataloaders = utils.setup_data(args, augmentation=False)
# Build the VAE and, when load_model is True, restore weights from model_path.
# The architecture arguments (h_dim1/h_dim2/zdim, ...) must match the
# checkpoint, otherwise load_state_dict (strict by default) raises.
load_model = True
model_name = "mmd_beta_1.0_epoch_50.torch"
model_path = os.path.join(args.root, 'data/mmd_vae', model_name)
model = m.VAE(image_channels=args.image_channels,
              image_size=args.image_size,
              h_dim1=1024,
              h_dim2=128,
              zdim=args.z_dim).to(c.device)
if load_model:
    # map_location remaps tensors saved on another device (e.g. a GPU) onto
    # c.device, so the checkpoint also loads on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=c.device))
# Everything below is inference/visualisation only; switch layers such as
# dropout/batch-norm (if the VAE has any) to evaluation behaviour.
model.eval()
# Read surgical_labels.csv from the configured data directory.
labels_csv_path = os.path.join(args.root, args.data_dir, 'surgical_labels.csv')
labels = pd.read_csv(labels_csv_path)
# Show the raw start (val item 1) and end (val item 9) images side by side.
# transpose(1, 2, 0) moves the first axis last so imshow gets a
# (rows, cols, channels) array.
start_arr = datasets['val'][1][0].numpy()
end_arr = datasets['val'][9][0].numpy()
fig = plt.figure()
plt.title("Initial Images\nStart, End")
pair = np.hstack([start_arr.transpose(1, 2, 0), end_arr.transpose(1, 2, 0)])
plt.imshow(pair)
# Plot the start/end validation images (top row) above the model's
# reconstructions of them (bottom row).
fig = plt.figure()
start_img = datasets['val'][1][0]
end_img = datasets['val'][9][0]
# Pure inference: no_grad avoids building an autograd graph for these passes.
with torch.no_grad():
    # The model returns a 4-tuple; only the reconstruction (and z) are kept.
    recon1, z, _, _ = model(start_img.unsqueeze(0).to(c.device))
    recon2, z, _, _ = model(end_img.unsqueeze(0).to(c.device))
recon1 = utils.torch_to_image(recon1)
recon2 = utils.torch_to_image(recon2)
originals = np.hstack([utils.torch_to_image(start_img),
                       utils.torch_to_image(end_img)])
recons = np.hstack([recon1, recon2])
plt.imshow(np.vstack([originals, recons]))
# Interpolate in latent space (via v.latent_interpolation) between validation
# items 1 and 9, plot the result, and save it as mmd_tool_motion.png.
# NOTE(review): the title says "Beta=5" but the checkpoint loaded above is
# mmd_beta_1.0 — confirm which label is intended.
interp_start = datasets['val'][1][0]
interp_end = datasets['val'][9][0]
images = v.latent_interpolation(interp_start, interp_end, model=model)
fig = v.plot_interpolation(images, "Interpolation\nBeta=5")
motion_png = os.path.join(args.root, 'data/mmd_vae', 'mmd_tool_motion.png')
plt.savefig(motion_png, bbox_inches='tight', dpi=400, pad_inches=0.0)
# Compare the latent vectors of the start (val 1) and end (val 9) images:
# first both overlaid, then their element-wise difference.
with torch.no_grad():
    # [0] takes the first row of what get_latent_vector returns
    # (presumably a batch of one — confirm against v.get_latent_vector).
    a = utils.torch_to_numpy(v.get_latent_vector(datasets['val'][1][0], model))[0]
    b = utils.torch_to_numpy(v.get_latent_vector(datasets['val'][9][0], model))[0]
diff = a - b
fig = plt.figure()
plt.plot(a, label='start (val[1])')
plt.plot(b, label='end (val[9])')
plt.xlabel('latent index')
plt.legend()
fig = plt.figure()
# Reuse the precomputed difference rather than recomputing a - b.
plt.plot(diff)
plt.xlabel('latent index')
plt.title('start - end')
# For each latent dimension, plot v.explore_latent_dimension's images for the
# start validation image. Iterate over the model's real latent size
# (args.z_dim, the value passed as zdim to m.VAE) instead of a hard-coded 64,
# which breaks or skips dimensions whenever z_dim != 64.
with torch.no_grad():
    for zdim in range(args.z_dim):
        images = v.explore_latent_dimension(datasets['val'][1][0], model, zdim=zdim)
        fig = v.plot_interpolation(images, title='zdim {}'.format(zdim))
# Repeat the start/end view and latent interpolation for training items
# 360 and 368, saving the strip as mmd_tool_motion2.png.
# NOTE(review): the title says "Beta=5" but the checkpoint loaded above is
# mmd_beta_1.0 — confirm which label is intended.
train_start = datasets['train'][360][0]
train_end = datasets['train'][368][0]
fig = plt.figure()
plt.title("Initial Images\nStart, End")
# transpose(1, 2, 0) moves the first axis last for imshow.
plt.imshow(np.hstack([train_start.numpy().transpose(1, 2, 0),
                      train_end.numpy().transpose(1, 2, 0)]))
images = v.latent_interpolation(train_start, train_end, model=model)
fig = v.plot_interpolation(images, "Interpolation\nBeta=5")
motion2_png = os.path.join(args.root, 'data/mmd_vae', 'mmd_tool_motion2.png')
plt.savefig(motion2_png, bbox_inches='tight', dpi=400, pad_inches=0.0)
# For each latent dimension, plot v.explore_latent_dimension's images for
# training item 360. Use the model's real latent size (args.z_dim, passed as
# zdim to m.VAE) rather than a hard-coded 64.
with torch.no_grad():
    for zdim in range(args.z_dim):
        images = v.explore_latent_dimension(datasets['train'][360][0], model, zdim=zdim)
        fig = v.plot_interpolation(images, title='zdim {}'.format(zdim))